K-means 是一種簡單且廣泛使用的聚類(Clustering)算法,目標是將一組數據劃分為 K 個不同的群體,使得每個數據點屬於與其最接近的群體。
初始化:隨機選擇 K 個數據點作為初始的群體中心(centroid)。
分配:對於每個數據點,計算其與每個群體中心的距離,並將其分配到最近的群體中心所在的群體中。
更新中心:對於每個群體,計算其所有成員的平均值,並將這個平均值作為新的群體中心。
重複:重複步驟 2 和 3,直到群體中心不再發生變化或達到一個預先定義的停止條件(如最大迭代次數)。
收斂:一旦群體中心不再發生變化,算法收斂,並且每個數據點都被分配到了一個群體中。
from sklearn.datasets import load_digits
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
# 載入手寫數字數據集
digits = load_digits()
X = digits.data
y = digits.target
# 初始化 K-means 模型,指定要分為幾個群體(這裡指定為 10,因為我們要分類 0 到 9 的數字)
kmeans = KMeans(n_clusters=10, random_state=42)
# 進行 K-means 聚類
kmeans.fit(X)
# 獲取聚類結果
labels = kmeans.labels_
# 繪製聚類中心(這裡展示了每個群體的平均數字)
fig, axes = plt.subplots(2, 5, figsize=(8, 4))
for i, ax in enumerate(axes.flat):
ax.imshow(kmeans.cluster_centers_[i].reshape(8, 8), cmap='binary')
ax.set_title(f'Cluster {i}')
ax.axis('off')
plt.show()